import numpy as np
from gps.kde import kde_gps
from gps.cnf.normflow import CNF
import torch
from utils.kernel import e_kernel
from model.rkhs.rkhs_approx import ApproxRKHSIVCV


class RKHS_Trainer():
    """Fit two RKHS-approximated functions, ``h`` and ``q``, and evaluate them.

    ``fit_h_cv`` regresses the outcome on ``[A, W, X]`` (with ``[A, Z, X]`` as
    the second input of ``ApproxRKHSIVCV.fit``) and stores the fitted predictor
    as ``self.h``.  ``fit_q_cv`` does the same with an inverse-density target
    and stores the predictor as ``self.q``.  The ``_htest`` / ``_qtest`` /
    ``_drtest`` methods then average these predictors over a test sample for
    every requested treatment level.

    NOTE(review): the variable names (treatment proxy, outcome proxy, "ATE")
    suggest proximal causal inference bridge functions -- confirm against the
    accompanying paper/model code.

    Attributes set by the ``fit_*`` methods:
        gamma_gm, gamma_hq, best_alpha_scale: hyper-parameters copied from the
            most recent fit (a later fit overwrites the earlier values).
        h: predictor produced by ``fit_h_cv``.
        q: predictor produced by ``fit_q_cv``.
    """

    # Values accepted by ``fit_q_cv(type=...)``.
    _GPS_TYPES = ('kde', 'normflow_new')

    def __init__(self, cfg, dataset):
        """Store the configuration object and the training dataset."""
        self.cfg = cfg
        self.dataset = dataset

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _design_matrices(self):
        """Return ``(AWX, AZX)`` built column-wise from the training dataset.

        The backdoor block ``X`` is appended to both matrices only when the
        dataset provides it (``dataset.backdoor is not None``).
        """
        ds = self.dataset
        awx_cols = [ds.treatment, ds.outcome_proxy]
        azx_cols = [ds.treatment, ds.treatment_proxy]
        if ds.backdoor is not None:
            awx_cols.append(ds.backdoor)
            azx_cols.append(ds.backdoor)
        return np.concatenate(awx_cols, axis=1), np.concatenate(azx_cols, axis=1)

    def _new_rkhs(self):
        """Create an unfitted ``ApproxRKHSIVCV`` from ``cfg.rkhs``."""
        rk = self.cfg.rkhs
        return ApproxRKHSIVCV(
            gamma_gm=rk.gamma_gm,
            n_gamma_hqs=rk.n_gamma_hqs,
            n_alphas=rk.n_alphas,
            alpha_scales=rk.alpha_scales,
            cv=rk.cv,
            n_components=rk.n_components,
        )

    def _store_hyperparams(self, fitted):
        """Copy the selected hyper-parameters of ``fitted`` onto ``self``."""
        self.gamma_gm = fitted.gamma_gm
        self.gamma_hq = fitted.gamma_hq
        self.best_alpha_scale = fitted.best_alpha_scale

    def _inverse_density_cnf(self):
        """Fit a ``CNF`` on ``(A, [W, X])`` and return ``1 / pob`` as numpy."""
        dens = self.cfg.density_5000
        flow = CNF(DEVICE=dens.device, n_layers=dens.n_layers, hidden=dens.hidden,
                   batch_size=dens.batch_size, lr=dens.init_lr,
                   n_epochs=dens.n_epochs, weight_decay=dens.weight_decay)
        WX = self.dataset.outcome_proxy
        if self.dataset.backdoor is not None:
            WX = np.concatenate([WX, self.dataset.backdoor], axis=1)
        WX = torch.tensor(WX, dtype=torch.float32)
        A_torch = torch.tensor(self.dataset.treatment, dtype=torch.float32)
        flow.fit(A_torch, WX)
        return (1 / flow.pob(A_torch, WX)).cpu().detach().numpy()

    @staticmethod
    def _stack_inputs(a, proxy, X):
        """Return ``[a, proxy, X]`` column-stacked, with ``a`` repeated per row.

        ``X`` is omitted when it is ``None``.
        """
        cols = [np.full((len(proxy), 1), a), proxy]
        if X is not None:
            cols.append(X)
        return np.concatenate(cols, axis=1)

    @staticmethod
    def _flat(values):
        """Return ``values`` as a 1-D array.

        ``A`` and ``Y`` are raveled before use in the test methods; flattening
        predictor outputs as well prevents an ``(n, 1)`` column from silently
        broadcasting against ``(n,)`` vectors into an ``(n, n)`` matrix.
        """
        return np.asarray(values).ravel()

    @staticmethod
    def _bandwidth(A):
        """Kernel bandwidth ``1.5 * std(A) * n**-0.2`` for the 1-D treatment."""
        return 1.5 * np.std(A) * (len(A) ** -0.2)

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------
    def fit_h_cv(self,):
        """Fit ``h`` with ``ApproxRKHSIVCV`` in ``'estimate_h'`` mode.

        Sets ``self.h`` plus the hyper-parameter attributes and returns
        ``self`` so calls can be chained.
        """
        AWX, AZX = self._design_matrices()
        RKHS_h = self._new_rkhs().fit(AWX, self.dataset.outcome, AZX, 'estimate_h')
        self._store_hyperparams(RKHS_h)
        self.h = RKHS_h.predict
        return self

    def fit_q_cv(self,type = 'kde'):
        """Fit ``q`` with ``ApproxRKHSIVCV`` in ``'estimate_q'`` mode.

        Args:
            type: how the regression target is produced -- ``'kde'`` calls
                ``kde_gps(A, W, X)``; ``'normflow_new'`` fits a ``CNF`` and
                uses the reciprocal of its ``pob`` output.

        Returns:
            ``self``, with ``self.q`` and the hyper-parameter attributes set.

        Raises:
            ValueError: if ``type`` is not one of the supported values.
        """
        # Fail fast: previously an unknown value left ``gps`` unbound and
        # surfaced as an UnboundLocalError only after the estimator was built.
        if type not in self._GPS_TYPES:
            raise ValueError(
                f"Unknown type {type!r}; expected one of {self._GPS_TYPES}")

        AWX, AZX = self._design_matrices()
        RKHS = self._new_rkhs()
        if type == 'kde':
            gps = kde_gps(self.dataset.treatment, self.dataset.outcome_proxy,
                          self.dataset.backdoor)
        else:
            gps = self._inverse_density_cnf()

        RKHS_q = RKHS.fit(AWX, gps, AZX, 'estimate_q')
        # NOTE(review): this overwrites the hyper-parameters stored by
        # fit_h_cv; only the most recent fit's values remain on ``self``.
        self._store_hyperparams(RKHS_q)
        self.q = RKHS_q.predict
        return self

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------
    def _htest(self,pointA,sampleTest) -> list:
        """Average ``h(a, W, X)`` over ``sampleTest`` for each ``a`` in ``pointA``.

        Returns:
            A list with one mean per treatment level, in ``pointA`` order.
        """
        W = sampleTest.outcome_proxy
        X = sampleTest.backdoor
        return [np.mean(self.h(self._stack_inputs(a, W, X))) for a in pointA]

    def _qtest(self,pointA,sampleTest) -> list:
        """Kernel-weighted estimate using ``q`` for each ``a`` in ``pointA``.

        For every level the value is
        ``mean(e_kernel(A - a, bandwidth) * max(q(a, Z, X), 0.01) * Y)``.

        Returns:
            A list with one estimate per treatment level, in ``pointA`` order.
        """
        Z = sampleTest.treatment_proxy
        X = sampleTest.backdoor
        Y = sampleTest.outcome
        A = sampleTest.treatment

        Y = Y.ravel() if Y.ndim == 2 else Y
        A = A.ravel() if A.ndim == 2 else A
        bandwidth = self._bandwidth(A)

        ATE_list = []
        for a in pointA:
            # Floor q at 0.01; np.maximum avoids mutating the predictor's
            # returned array in place.
            q_azx = np.maximum(self._flat(self.q(self._stack_inputs(a, Z, X))), 0.01)
            ATE_list.append(np.mean(e_kernel(A - a, bandwidth) * q_azx * Y))
        return ATE_list

    def _drtest(self,pointA,sampleTest) -> list:
        """Combine ``h`` and ``q`` into one estimate for each ``a`` in ``pointA``.

        For every level the value is
        ``mean((Y - h) * q * e_kernel(A - a, bandwidth) + h)`` with
        ``h = h(a, W, X)`` and ``q = q(a, Z, X)``.

        NOTE(review): unlike ``_qtest``, ``q`` is not floored at 0.01 here --
        confirm whether that difference is intentional.

        Returns:
            A list with one estimate per treatment level, in ``pointA`` order.
        """
        A = sampleTest.treatment
        W = sampleTest.outcome_proxy
        Z = sampleTest.treatment_proxy
        X = sampleTest.backdoor
        Y = sampleTest.outcome

        Y = Y.ravel() if Y.ndim == 2 else Y
        A = A.ravel() if A.ndim == 2 else A
        bandwidth = self._bandwidth(A)

        ATE_list = []
        for a in pointA:
            # Evaluate h once per level; it is used twice in the formula.
            h_vals = self._flat(self.h(self._stack_inputs(a, W, X)))
            q_vals = self._flat(self.q(self._stack_inputs(a, Z, X)))
            weights = e_kernel(A - a, bandwidth)
            ATE_list.append(np.mean((Y - h_vals) * q_vals * weights + h_vals))
        return ATE_list
    
